package com.lw.leetcode.hash.b;

import java.util.HashMap;
import java.util.Map;
import java.util.Set;

/**
 * Created with IntelliJ IDEA.
 * hash
 * 1577. 数的平方等于两数乘积的方法数
 * (LeetCode 1577: Number of Ways Where Square of Number Is Equal to Product of Two Numbers)
 *
 * @author liw
 * @version 1.0
 * @date 2021/12/16 22:26
 */
public class NumTriplets {

    /**
     * Small manual driver: runs {@link #numTriplets(int[], int[])} on one sample and prints the result.
     *
     * @param args unused
     */
    public static void main(String[] args) {
        NumTriplets test = new NumTriplets();

        // Sample inputs and their expected results:
        //   {7, 4}              and {5, 2, 8, 9}           -> 1
        //   {1, 1}              and {1, 1, 1}              -> 9
        //   {4, 7, 9, 11, 23}   and {3, 5, 1024, 12, 18}   -> 0
        int[] arr = {4, 7, 9, 11, 23};
        int[] arr2 = {3, 5, 1024, 12, 18};

        int i = test.numTriplets(arr, arr2);
        System.out.println(i);
    }

    /**
     * Counts the triplets {@code (i, j, k)} with {@code j < k} such that either
     * {@code nums1[i]^2 == nums2[j] * nums2[k]} or {@code nums2[i]^2 == nums1[j] * nums1[k]}.
     *
     * <p>Values are grouped by frequency first, so the work is proportional to the product of
     * the numbers of distinct values rather than to the cube of the array lengths.
     *
     * @param nums1 first array; must not be {@code null}
     * @param nums2 second array; must not be {@code null}
     * @return the total number of matching triplets of both kinds
     * @throws ArithmeticException if the total does not fit in an {@code int}
     */
    public int numTriplets(int[] nums1, int[] nums2) {
        Map<Long, Long> counts1 = countValues(nums1);
        Map<Long, Long> counts2 = countValues(nums2);
        // Accumulate in long and convert once, so an out-of-range total fails loudly
        // instead of silently wrapping around.
        long total = countTriplets(counts1, counts2) + countTriplets(counts2, counts1);
        return Math.toIntExact(total);
    }

    /**
     * Builds a value-to-occurrence-count map; values are widened to {@code long} so that
     * squaring them later cannot overflow.
     */
    private static Map<Long, Long> countValues(int[] nums) {
        Map<Long, Long> counts = new HashMap<>();
        for (long num : nums) {
            counts.merge(num, 1L, Long::sum);
        }
        return counts;
    }

    /**
     * Counts triplets whose squared element comes from {@code squares} and whose pair of
     * factors comes from {@code factors}.
     *
     * @param squares value-to-count map of the array supplying the squared element
     * @param factors value-to-count map of the array supplying the two multiplied elements
     * @return number of (element, unordered index pair) combinations with element^2 == product
     */
    private static long countTriplets(Map<Long, Long> squares, Map<Long, Long> factors) {
        // Loop-invariant: how many index pairs of the factor array multiply to zero.
        long zeroProductPairs = countZeroProductPairs(factors);

        long ways = 0;
        for (Map.Entry<Long, Long> squareEntry : squares.entrySet()) {
            long base = squareEntry.getKey();
            long multiplicity = squareEntry.getValue();

            if (base == 0) {
                // 0^2 == 0: every pair containing at least one zero matches.
                ways += zeroProductPairs * multiplicity;
                continue;
            }

            long target = base * base;
            for (Map.Entry<Long, Long> factorEntry : factors.entrySet()) {
                long first = factorEntry.getKey();
                // Zero can never be a factor of a non-zero target (and must not be a divisor).
                if (first == 0 || target % first != 0) {
                    continue;
                }
                long second = target / first;
                // Visit each unordered pair of values exactly once.
                if (first > second) {
                    continue;
                }
                Long secondCount = factors.get(second);
                if (secondCount == null) {
                    continue;
                }
                long firstCount = factorEntry.getValue();
                if (first == second) {
                    // Same value twice: choose 2 of its occurrences.
                    ways += firstCount * (firstCount - 1) / 2 * multiplicity;
                } else {
                    ways += firstCount * secondCount * multiplicity;
                }
            }
        }
        return ways;
    }

    /**
     * Returns the number of unordered index pairs whose product is zero, i.e. pairs that
     * contain at least one zero element.
     */
    private static long countZeroProductPairs(Map<Long, Long> counts) {
        long zeros = counts.getOrDefault(0L, 0L);
        if (zeros == 0) {
            return 0;
        }
        long size = 0;
        for (long count : counts.values()) {
            size += count;
        }
        // (zero, zero) pairs plus (zero, non-zero) pairs.
        return zeros * (zeros - 1) / 2 + zeros * (size - zeros);
    }

}
